package com.ddwanglife.ml

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row

object PipelinesDemo {

  /**
   * Demonstrates a Spark ML `Pipeline`: trains a tiny text classifier
   * (Tokenizer -> HashingTF -> LogisticRegression) on four labeled
   * documents, then scores four unlabeled documents and prints the
   * predicted probability and class for each.
   *
   * Side effects: creates (and now properly stops) a local SparkSession;
   * prints one line per test row to stdout.
   */
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local").appName("my App Name").getOrCreate()
    import spark.implicits._

    // Ensure the Spark context is released even if fitting/scoring throws;
    // previously spark.stop() was never called, leaking the local context.
    try {
      // Build the training dataset: (id, text, label), where label 1.0
      // marks documents that should be classified as "spark-related".
      val training = spark.createDataFrame(Seq(
               (0L, "a b c d e spark", 1.0),
               (1L, "b d", 0.0),
               (2L, "spark f g h", 1.0),
               (3L, "hadoop mapreduce", 0.0)
               )).toDF("id", "text", "label")
      /*
        training.show()
        +---+----------------+-----+
        | id|            text|label|
        +---+----------------+-----+
        |  0| a b c d e spark|  1.0|
        |  1|             b d|  0.0|
        |  2|     spark f g h|  1.0|
        |  3|hadoop mapreduce|  0.0|
        +---+----------------+-----+
       */

      // Stage 1: split each "text" value on whitespace into a "words" array.
      val tokenizer = new Tokenizer()
        .setInputCol("text")
        .setOutputCol("words")

      /*
      tokenizer.transform(training).show()
      +---+----------------+-----+--------------------+
      | id|            text|label|               words|
      +---+----------------+-----+--------------------+
      |  0| a b c d e spark|  1.0|[a, b, c, d, e, s...|
      |  1|             b d|  0.0|              [b, d]|
      |  2|     spark f g h|  1.0|    [spark, f, g, h]|
      |  3|hadoop mapreduce|  0.0| [hadoop, mapreduce]|
      +---+----------------+-----+--------------------+
       */

      // Stage 2: hash each word into a 1000-dimensional sparse
      // term-frequency vector stored in the "features" column.
      val hashingTF = new HashingTF()
        .setNumFeatures(1000)
        .setInputCol(tokenizer.getOutputCol)
        .setOutputCol("features")
       /*
        hashingTF.transform(tokenizer.transform(training)).show()
        +---+----------------+-----+--------------------+--------------------+
        | id|            text|label|               words|            features|
        +---+----------------+-----+--------------------+--------------------+
        |  0| a b c d e spark|  1.0|[a, b, c, d, e, s...|(1000,[94,105,170...|
        |  1|             b d|  0.0|              [b, d]|(1000,[94,361],[1...|
        |  2|     spark f g h|  1.0|    [spark, f, g, h]|(1000,[105,248,28...|
        |  3|hadoop mapreduce|  0.0| [hadoop, mapreduce]|(1000,[181,953],[...|
        +---+----------------+-----+--------------------+--------------------+
        */

      // Stage 3: logistic regression estimator — at most 10 iterations,
      // regularization parameter 0.01.
      val lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.01)

      // Chain the three stages; fit() runs the transformers and trains the
      // estimator, yielding a fully-fitted PipelineModel.
      val pipeline = new Pipeline()
        .setStages(Array(tokenizer, hashingTF, lr))

      val model = pipeline.fit(training)

      // Unlabeled documents to score with the fitted model.
      val test = spark.createDataFrame(Seq(
               (4L, "spark i j k"),
               (5L, "l m n"),
               (6L, "spark a"),
               (7L, "apache hadoop")
               )).toDF("id", "text")

      // transform() appends words/features/probability/prediction columns;
      // collect() brings the small result set to the driver for printing.
      model.transform(test)
        .select("id", "text", "probability", "prediction")
        .collect()
        .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
          println(s"($id, $text) --> prob=$prob, prediction=$prediction")
        }
    } finally {
      // Release the local Spark context so the JVM can exit cleanly.
      spark.stop()
    }
  }

}
